ANN with MNIST


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

1. What's an MNIST?

From Wikipedia

  • The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems. The database is also widely used for training and testing in the field of machine learning. It was created by "re-mixing" the samples from NIST's original datasets. The creators felt that since NIST's training dataset was taken from American Census Bureau employees, while the testing dataset was taken from American high school students, it was not well-suited for machine learning experiments.
  • MNIST (Modified National Institute of Standards and Technology) database
    • Handwritten digit database
    • $28 \times 28$ grayscale image
    • Flattened matrix into a vector of $28 \times 28 = 784$
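
To make the flattening concrete, here is a minimal NumPy sketch. It uses a synthetic $28 \times 28$ array (not real MNIST data) whose pixel values encode their own position, and it assumes row-major (C-order) flattening, which is NumPy's default.

```python
import numpy as np

# Synthetic 28x28 "image" whose values encode their own position,
# so the flattening order is easy to inspect (not real MNIST data).
rows, cols = 28, 28
image = np.arange(rows * cols, dtype=np.float32).reshape(rows, cols)

# Row-major (C-order) flattening: the rows are laid end to end.
flat = image.reshape(-1)
print(flat.shape)

# Pixel (r, c) of the image sits at index r * 28 + c of the vector.
r, c = 3, 5
assert flat[r * cols + c] == image[r, c]

# Reshaping back recovers the original image exactly.
restored = flat.reshape(rows, cols)
assert np.array_equal(restored, image)
```

Nothing is lost by flattening: the same 784 numbers can be reshaped back into the image whenever we want to display it.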



We will be using MNIST to create a multinomial classifier that can detect whether the MNIST image shown is a member of class 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9. Succinctly, we're teaching a computer to recognize handwritten digits.

In [1]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
%matplotlib inline

Let's download and load the dataset.

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
WARNING:tensorflow:From <ipython-input-2-8bf8ae5a5303>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From c:\users\seungchul\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From c:\users\seungchul\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From c:\users\seungchul\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From c:\users\seungchul\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From c:\users\seungchul\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
In [3]:
print ("The training data set is:\n")
print (mnist.train.images.shape)
print (mnist.train.labels.shape)
The training data set is:

(55000, 784)
(55000, 10)
In [4]:
print ("The test data set is:")
print (mnist.test.images.shape)
print (mnist.test.labels.shape)
The test data set is:
(10000, 784)
(10000, 10)
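
The loader used above is deprecated, as the warnings say. Raw MNIST images are usually distributed as uint8 arrays of shape (N, 28, 28) with values in 0..255 (this is, for example, what tf.keras.datasets.mnist.load_data() returns in newer TensorFlow versions). The tutorial loader holds out 5,000 of the 60,000 training images for validation, which is why 55,000 remain. The sketch below shows how such raw arrays could be brought into the (N, 784) float format printed above. The arrays are small random stand-ins so the cell runs without a download, and the names to_flat_float and n_val are made up for this illustration.

```python
import numpy as np

# Stand-in for raw MNIST-style data: uint8 images in [0, 255] and integer labels.
# (Random values and a small sample count, purely for illustration.)
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(600, 28, 28), dtype=np.uint8)
raw_labels = rng.integers(0, 10, size=600)

def to_flat_float(images):
    """Scale uint8 pixels to [0, 1] float32 and flatten each image to 784 values."""
    return images.reshape(len(images), -1).astype(np.float32) / 255.0

# Hold out the first n_val samples for validation (an assumption of this sketch).
n_val = 50
val_images = to_flat_float(raw_images[:n_val])
train_images = to_flat_float(raw_images[n_val:])
train_labels = raw_labels[n_val:]

print(train_images.shape, train_images.dtype)
print(val_images.shape)
```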

Let's look at one sample from the training set:

In [5]:
mnist.train.images[5]
Out[5]:
array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.23137257,  0.63921571,  0.99607849,  0.99607849,  0.99607849,
        0.76078439,  0.43921572,  0.07058824,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.01568628,  0.51764709,  0.93725497,  0.99215692,
        0.99215692,  0.99215692,  0.99215692,  0.99607849,  0.99215692,
        0.627451  ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.53725493,
        0.99215692,  0.99607849,  0.99215692,  0.99215692,  0.99215692,
        0.75294125,  0.99607849,  0.99215692,  0.89803928,  0.0509804 ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.01568628,  0.53725493,  0.98431379,  0.99215692,  0.95686281,
        0.50980395,  0.19215688,  0.07450981,  0.01960784,  0.63921571,
        0.99215692,  0.82352948,  0.03529412,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.37254903,  0.99215692,
        0.99215692,  0.84313732,  0.17647059,  0.        ,  0.        ,
        0.        ,  0.        ,  0.61176473,  0.99215692,  0.68235296,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.84313732,  0.99607849,  0.81176478,  0.09019608,
        0.        ,  0.        ,  0.        ,  0.03921569,  0.38039219,
        0.85098046,  0.91764712,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.83921576,
        0.99215692,  0.27843139,  0.        ,  0.        ,  0.00784314,
        0.19607845,  0.83529419,  0.99215692,  0.99607849,  0.70588237,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.83921576,  0.99215692,  0.19215688,
        0.        ,  0.        ,  0.19607845,  0.99215692,  0.99215692,
        0.99215692,  0.71764708,  0.04705883,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.78039223,  0.99215692,  0.95294124,  0.76862752,  0.62352943,
        0.95294124,  0.99215692,  0.96862751,  0.5411765 ,  0.03137255,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.16470589,  0.99215692,
        0.99215692,  0.99215692,  0.99607849,  0.99215692,  0.99215692,
        0.39607847,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.23137257,  0.58431375,  0.99607849,  0.99607849,  0.99607849,
        1.        ,  0.99607849,  0.68627453,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.13333334,  0.75294125,  0.99607849,  0.99215692,
        0.99215692,  0.99215692,  0.7843138 ,  0.53333336,  0.89019614,
        0.9450981 ,  0.27058825,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.33333334,  0.96862751,
        0.99215692,  0.99607849,  0.99215692,  0.77647066,  0.48235297,
        0.07058824,  0.        ,  0.19607845,  0.99215692,  0.83529419,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.27843139,  0.96862751,  0.99215692,  0.92941183,  0.75294125,
        0.27843139,  0.02352941,  0.        ,  0.        ,  0.        ,
        0.00784314,  0.50196081,  0.98039222,  0.21176472,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.46274513,  0.99215692,
        0.8705883 ,  0.14117648,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.03137255,  0.71764708,
        0.99215692,  0.227451  ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.46274513,  0.99607849,  0.54509807,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.05490196,  0.72941178,  0.99607849,  0.99607849,  0.227451  ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.27843139,
        0.96862751,  0.96862751,  0.54509807,  0.0627451 ,  0.        ,
        0.        ,  0.07450981,  0.227451  ,  0.87843144,  0.99215692,
        0.99215692,  0.83137262,  0.03529412,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.42352945,  0.99215692,
        0.99215692,  0.92549026,  0.68627453,  0.68627453,  0.96862751,
        0.99215692,  0.99607849,  0.99215692,  0.77647066,  0.16862746,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.26274511,  0.83529419,  0.89803928,  0.99607849,
        0.99215692,  0.99215692,  0.99215692,  0.99215692,  0.83921576,
        0.48627454,  0.02352941,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.09019608,  0.60784316,  0.60784316,  0.87450987,
        0.7843138 ,  0.46274513,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)
In [6]:
# well, that's not a picture (or image), it's an array.

mnist.train.images[5].shape
Out[6]:
(784,)

You might expect each training sample to be stored as a $28 \times 28$ grayscale image of a handwritten digit. It is not.

Each $28 \times 28$ image has been flattened into a 1D array of 784 values. Let's reshape one back into an image.

In [7]:
img = np.reshape(mnist.train.images[5], [28,28])
In [8]:
img = mnist.train.images[5].reshape([28,28])
In [9]:
# So now we have a 28x28 matrix, where each element is an intensity level from 0 to 1.  
img.shape
Out[9]:
(28, 28)

Let's visualize what some of these images and their corresponding training labels look like.

In [10]:
plt.figure(figsize = (6,6))
plt.imshow(img, 'gray')
plt.xticks([])
plt.yticks([])
plt.show()
In [11]:
mnist.train.labels[5]
Out[11]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.])
In [12]:
np.argmax(mnist.train.labels[5])
Out[12]:
8

Built-in batch maker: next_batch(n) returns the next n images together with their labels.

In [13]:
x, y = mnist.train.next_batch(3)

print(x.shape)
print(y.shape)
(3, 784)
(3, 10)
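
To see roughly what a batch maker does, here is a small NumPy sketch. It is not the library's implementation, only one plausible way to do it: shuffle the sample indices once, then hand out consecutive slices. The function name iterate_minibatches and the random stand-in data are made up for this example.

```python
import numpy as np

def iterate_minibatches(images, labels, batch_size, rng):
    """Yield (images, labels) mini-batches covering one shuffled pass over the data."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

rng = np.random.default_rng(0)
images = rng.random((10, 784), dtype=np.float32)                      # stand-in images
labels = np.eye(10, dtype=np.float32)[rng.integers(0, 10, size=10)]   # stand-in one-hot labels

batches = list(iterate_minibatches(images, labels, batch_size=3, rng=rng))
for xb, yb in batches:
    print(xb.shape, yb.shape)
```

With 10 samples and a batch size of 3, the last batch holds only the one remaining sample.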

2. ANN with TensorFlow

  • Feed a gray image to ANN


  • Our network model



  • Network training (learning), where $\ell$ is the loss and $\alpha$ is the learning rate: $$\omega := \omega - \alpha \nabla_{\omega} \ell\left( h_{\omega}\left(x^{(i)}\right), y^{(i)}\right)$$
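
The update rule can be tried out on a much smaller problem first. The sketch below applies it to a linear model with a squared-error loss (an assumption made only to keep the example tiny; the network in this notebook uses a different model and loss). For simplicity it uses the gradient averaged over all samples rather than a single sample $i$.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w                      # noiseless targets from a known weight vector

w = np.zeros(3)                     # initial weights
alpha = 0.1                         # learning rate
for step in range(200):
    residual = X @ w - y
    grad = 2.0 * X.T @ residual / len(X)   # gradient of the mean squared error
    w = w - alpha * grad                   # omega := omega - alpha * gradient

print(np.round(w, 3))
```

After enough steps the learned weights should agree closely with true_w.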

2.1. Import Library

In [14]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

2.2. Load MNIST Data

  • Download the MNIST data with the loader from the TensorFlow tutorial examples
In [15]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [16]:
train_x, train_y = mnist.train.next_batch(1)
img = train_x[0,:].reshape(28,28)

plt.figure(figsize=(6,6))
plt.imshow(img,'gray')
plt.title("Label : {}".format(np.argmax(train_y[0,:])))
plt.xticks([])
plt.yticks([])
plt.show()

One-hot encoding

In [17]:
print ('Train labels : {}'.format(train_y[0, :]))
Train labels : [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
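
One-hot encoding and its inverse are each a one-liner in NumPy. The helper name one_hot below is made up for this sketch.

```python
import numpy as np

def one_hot(labels, n_classes=10):
    """Turn integer class labels into one-hot rows."""
    return np.eye(n_classes, dtype=np.float32)[labels]

digits = np.array([4, 8, 0])
encoded = one_hot(digits)
print(encoded[0])                     # [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

# argmax along each row recovers the integer label.
decoded = np.argmax(encoded, axis=1)
print(decoded)                        # [4 8 0]
```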

2.3. Define an ANN Structure

  • Input size
  • Hidden layer size
  • The number of classes


In [18]:
n_input = 28*28
n_hidden = 100
n_output = 10

2.4. Define Weights, Biases, and Placeholder

  • Define parameters based on predefined layer size
  • Initialize them from a normal distribution with $\mu = 0$ and $\sigma = 0.1$
In [19]:
weights = {
    'hidden' : tf.Variable(tf.random_normal([n_input, n_hidden], stddev = 0.1)),
    'output' : tf.Variable(tf.random_normal([n_hidden, n_output], stddev = 0.1))
}

biases = {
    'hidden' : tf.Variable(tf.random_normal([n_hidden], stddev = 0.1)),
    'output' : tf.Variable(tf.random_normal([n_output], stddev = 0.1))
}
In [20]:
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
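
As a sanity check on the sizes, the same parameters can be created in plain NumPy and counted. This is only a stand-in for the TensorFlow variables above; the dictionary keys are made up for this sketch.

```python
import numpy as np

n_input, n_hidden, n_output = 28 * 28, 100, 10
rng = np.random.default_rng(0)

# Same shapes and the same N(0, 0.1^2) initialization as the TensorFlow variables.
params = {
    'W_hidden': rng.normal(0.0, 0.1, size=(n_input, n_hidden)),
    'b_hidden': rng.normal(0.0, 0.1, size=(n_hidden,)),
    'W_output': rng.normal(0.0, 0.1, size=(n_hidden, n_output)),
    'b_output': rng.normal(0.0, 0.1, size=(n_output,)),
}

for name, value in params.items():
    print(name, value.shape)

# 784*100 + 100 + 100*10 + 10 trainable numbers in total.
n_params = sum(value.size for value in params.values())
print(n_params)                       # 79510
```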

2.5. Build a Model

First, the layer performs a matrix multiplication to produce a set of linear activations



$$y_j = \left(\sum\limits_i \omega_{ij}x_i\right) + b_j$$
$$y = \omega^T x + b$$


Second, each linear activation is passed through a nonlinear activation function




Third, predict values with an affine transformation



In [21]:
# Define Network
def build_model(x, weights, biases):
    
    # first hidden layer
    hidden = tf.add(tf.matmul(x, weights['hidden']), biases['hidden'])
    # nonlinear activation function (ReLU)
    hidden = tf.nn.relu(hidden)
    
    # Output layer 
    output = tf.add(tf.matmul(hidden, weights['output']), biases['output'])
    
    return output
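
The three steps can be mirrored in NumPy to check the shapes by hand. This is a sketch of the same computation as build_model, not a replacement for it; the function name forward and the random inputs are made up for this example.

```python
import numpy as np

def forward(x, weights, biases):
    """Affine layer, ReLU, then a second affine layer (returns raw scores, i.e. logits)."""
    hidden = x @ weights['hidden'] + biases['hidden']     # (batch, 100) linear activations
    hidden = np.maximum(hidden, 0.0)                      # ReLU keeps positives, zeroes the rest
    return hidden @ weights['output'] + biases['output']  # (batch, 10)

rng = np.random.default_rng(0)
weights = {
    'hidden': rng.normal(0.0, 0.1, size=(784, 100)),
    'output': rng.normal(0.0, 0.1, size=(100, 10)),
}
biases = {
    'hidden': rng.normal(0.0, 0.1, size=(100,)),
    'output': rng.normal(0.0, 0.1, size=(10,)),
}

batch = rng.random((3, 784))          # three stand-in "images"
logits = forward(batch, weights, biases)
print(logits.shape)                   # (3, 10)
```

Each row of the result holds one score per digit class. No softmax is applied here, just as in build_model.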

2.6. Define Loss and Optimizer

Loss

  • This defines how we measure how accurate the model is during training. As was covered in lecture, during training we want to minimize this function, which will "steer" the model in the right direction.
  • Classification: Cross entropy
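    • Equivalent to applying logistic regression; for a binary label $y^{(i)} \in \{0, 1\}$ the loss is
$$ -\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log\left(h_{\theta}\left(x^{(i)}\right)\right) + \left(1-y^{(i)}\right)\log\left(1-h_{\theta}\left(x^{(i)}\right)\right)\right] $$
    • With 10 classes and one-hot labels, this generalizes to the softmax cross entropy
$$ -\frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{10} y_k^{(i)}\log\left(h_{\theta}\left(x^{(i)}\right)_k\right) $$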
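
The code below uses the softmax version of cross entropy, since there are 10 classes rather than 2. Here is a NumPy sketch of that loss, written in a numerically stable way. It is meant to illustrate the formula, not to reproduce TensorFlow's internals.

```python
import numpy as np

def softmax_cross_entropy(logits, one_hot_labels):
    """Mean cross entropy between softmax(logits) and one-hot labels."""
    # Subtracting the row maximum does not change the softmax but avoids overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(one_hot_labels * log_probs).sum(axis=1).mean()

labels = np.eye(10)[[3, 7]]           # two samples with true classes 3 and 7

# A model that cannot tell the classes apart assigns probability 1/10 to each.
loss_uniform = softmax_cross_entropy(np.zeros((2, 10)), labels)
print(loss_uniform)                   # about 2.3026, i.e. ln(10)

# A confident and correct model gets a loss close to zero.
loss_confident = softmax_cross_entropy(20.0 * labels, labels)
print(loss_confident)
```

The value $\ln 10 \approx 2.30$ is a useful reference point: an untrained network usually starts with a cost in that neighborhood.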

Optimizer

  • This defines how the model is updated based on the data it sees and its loss function.
  • AdamOptimizer: a widely used optimizer that adapts the step size for each parameter
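
For intuition, here is a NumPy sketch of the Adam update rule on a toy quadratic problem. It follows the commonly published form of the algorithm (moving averages of the gradient and of its square, with bias correction). The function name adam_step and the toy objective are made up for this example, and this is not TensorFlow's code.

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step counter."""
    m = beta1 * m + (1 - beta1) * grad           # moving average of the gradient
    v = beta2 * v + (1 - beta2) * grad ** 2      # moving average of the squared gradient
    m_hat = m / (1 - beta1 ** t)                 # bias correction for the zero start
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # per-parameter scaled step
    return w, m, v

# Toy objective: f(w) = ||w - target||^2, with gradient 2 * (w - target).
target = np.array([3.0, -1.0])
w = np.zeros(2)
m = np.zeros(2)
v = np.zeros(2)
for t in range(1, 2001):
    grad = 2.0 * (w - target)
    w, m, v = adam_step(w, grad, m, v, t, lr=0.01)

print(np.round(w, 3))
```

Early on, each step has a size of roughly lr regardless of the gradient's magnitude, which is part of why Adam is fairly forgiving about the learning rate.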
In [22]:
# Define Loss
pred = build_model(x, weights, biases)
loss = tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y)
loss = tf.reduce_mean(loss)

LR = 0.0001
optm = tf.train.AdamOptimizer(LR).minimize(loss)
WARNING:tensorflow:From <ipython-input-22-37eee6b079c7>:3: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

2.7. Define Optimization Configuration and Then Optimize




  • Define parameters for training ANN
    • n_batch: batch size for mini-batch gradient descent
    • n_iter: the number of iteration steps
    • n_prt: check loss for every n_prt iteration
  • Metrics
    • Here we can define metrics used to monitor the training and testing steps. In this example, we'll look at the accuracy, the fraction of the images that are correctly classified.

Initializer

  • Initialize all the variables
In [23]:
n_batch = 50     # Batch Size
n_iter = 5000    # Learning Iteration
n_prt = 250      # Print Cycle
In [24]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

loss_record_train = []
loss_record_test = []
for epoch in range(n_iter):
    train_x, train_y = mnist.train.next_batch(n_batch)
    sess.run(optm, feed_dict = {x: train_x, y: train_y}) 
    
    if epoch % n_prt == 0:        
        test_x, test_y = mnist.test.next_batch(n_batch)
        c1 = sess.run(loss, feed_dict = {x: train_x, y: train_y})
        c2 = sess.run(loss, feed_dict = {x: test_x, y: test_y})
        loss_record_train.append(c1)
        loss_record_test.append(c2)
        print ("Iter : {}".format(epoch))
        print ("Cost : {}".format(c1))
        
plt.figure(figsize=(10,8))
plt.plot(np.arange(len(loss_record_train))*n_prt, 
         loss_record_train, label = 'training')
plt.plot(np.arange(len(loss_record_test))*n_prt, 
         loss_record_test, label = 'testing')
plt.xlabel('iteration', fontsize = 15)
plt.ylabel('loss', fontsize = 15)
plt.legend(fontsize = 12)
plt.ylim([0, np.max(loss_record_train)])
plt.show()
Iter : 0
Cost : 2.480926036834717
Iter : 250
Cost : 1.3548682928085327
Iter : 500
Cost : 0.7710088491439819
Iter : 750
Cost : 0.48300960659980774
Iter : 1000
Cost : 0.40704935789108276
Iter : 1250
Cost : 0.565524697303772
Iter : 1500
Cost : 0.3561488091945648
Iter : 1750
Cost : 0.3473952114582062
Iter : 2000
Cost : 0.3023255169391632
Iter : 2250
Cost : 0.36950623989105225
Iter : 2500
Cost : 0.3390611410140991
Iter : 2750
Cost : 0.3183514475822449
Iter : 3000
Cost : 0.1827264279127121
Iter : 3250
Cost : 0.16383935511112213
Iter : 3500
Cost : 0.2471444457769394
Iter : 3750
Cost : 0.21442127227783203
Iter : 4000
Cost : 0.3387083411216736
Iter : 4250
Cost : 0.425953209400177
Iter : 4500
Cost : 0.3077552914619446
Iter : 4750
Cost : 0.20078086853027344

2.8. Test or Evaluate

In [25]:
test_x, test_y = mnist.test.next_batch(100)

my_pred = sess.run(pred, feed_dict = {x : test_x})
my_pred = np.argmax(my_pred, axis = 1)

labels = np.argmax(test_y, axis = 1)

accr = np.mean(np.equal(my_pred, labels))
print("Accuracy : {}%".format(accr*100))
Accuracy : 91.0%
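
The number above comes from a single batch of 100 test images, so it can vary from run to run. A steadier estimate uses the whole test set, processed in batches. The sketch below shows one way to do this. The helper accuracy_in_batches and the toy model are made up for this example; in this notebook, predict_fn could be something like lambda xb: sess.run(pred, feed_dict = {x : xb}), applied to mnist.test.images and mnist.test.labels.

```python
import numpy as np

def accuracy_in_batches(predict_fn, images, one_hot_labels, batch_size=100):
    """Fraction of correctly classified samples, evaluated batch by batch."""
    n_correct = 0
    for start in range(0, len(images), batch_size):
        xb = images[start:start + batch_size]
        yb = one_hot_labels[start:start + batch_size]
        predicted = np.argmax(predict_fn(xb), axis=1)
        n_correct += int(np.sum(predicted == np.argmax(yb, axis=1)))
    return n_correct / len(images)

# Toy check: the "model" simply uses the first ten pixels as class scores.
def toy_predict(xb):
    return xb[:, :10]

rng = np.random.default_rng(0)
images = rng.random((250, 784))
true_classes = np.argmax(images[:, :10], axis=1)
true_classes[:50] = (true_classes[:50] + 1) % 10    # make 50 labels disagree with the model
labels = np.eye(10)[true_classes]

acc = accuracy_in_batches(toy_predict, images, labels)
print("Accuracy : {}%".format(acc * 100))
```

In the toy check, 200 of the 250 labels agree with the toy model, so the function should report 80% accuracy.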
In [26]:
test_x, test_y = mnist.test.next_batch(1)
# softmax turns the network's raw scores (logits) into class probabilities
probs = sess.run(tf.nn.softmax(pred), feed_dict = {x : test_x})
predict = np.argmax(probs)

plt.figure(figsize = (6,6))
plt.imshow(test_x.reshape(28,28), 'gray')
plt.xticks([])
plt.yticks([])
plt.show()

print('Prediction : {}'.format(predict))
np.set_printoptions(precision = 2, suppress = True)
print('Probability : {}'.format(probs.ravel()))
Prediction : 5
Probability : [ 0.    0.    0.    0.01  0.    0.98  0.    0.    0.    0.  ]

If you also compute the accuracy on the training data, you may observe that the accuracy on the test dataset is a little lower. This gap between training accuracy and test accuracy is an example of overfitting, in which a machine learning model performs worse on new data than on its training data. Keep in mind that the 91% above was measured on a single batch of 100 test images, so it is only a rough estimate.

What is the highest accuracy you can achieve with this first fully connected model? Since the handwritten digit classification task is pretty straightforward, you may be wondering how we can do better...

$\Rightarrow$ As we saw in lecture, convolutional neural networks (CNNs) are particularly well-suited for a variety of tasks in computer vision, and have achieved near-perfect accuracies on the MNIST dataset. We will build a CNN and ultimately output a probability distribution over the 10 digit classes (0-9) in the next lectures.
